import math
import random
from builtins import Exception
from itertools import chain
from typing import Sequence, Any, Tuple, List

from gym import spaces as spaces

from centralized_verification.MultiAgentAPEnv import MultiAgentSafetyEnv
from centralized_verification.envs.utils import add_posn


def flatten(iterator_to_flatten):
    """Return a single list holding every element of every inner iterable, in order.

    Only one level of nesting is removed.
    """
    return [element for inner in iterator_to_flatten for element in inner]


def posn_floor(p):
    """Round both coordinates of the 2-element position ``p`` down to integers.

    Returns a tuple ``(floor(x), floor(y))``.
    """
    x_coord, y_coord = p
    return tuple(map(math.floor, (x_coord, y_coord)))


def posn_dist(p1, p2):
    """Return the Euclidean distance between positions ``p1`` and ``p2``.

    Only the first two coordinates (index 0 and 1) of each position are used.
    """
    delta_x = p1[0] - p2[0]
    delta_y = p1[1] - p2[1]
    return math.sqrt((delta_x ** 2) + (delta_y ** 2))


class ContinuousGridWorld(MultiAgentSafetyEnv):
    """
    A two-agent grid world in which the agents occupy continuous positions.

    The map is the collection of integer grid cells given in ``grid_posns``. An agent's
    position is a pair ``(x, y)``; the cell it currently occupies is ``posn_floor((x, y))``.
    The environment state is the flat tuple ``(x0, y0, x1, y1)``.

    Observations:
    Every agent receives the same observation: the full state, with each x coordinate
    rescaled as ``2 * x / width - 1`` and each y coordinate as ``2 * y / height - 1``
    (see ``project_obs``).

    This environment generates a single AP, "collision". The agents are treated as having
    collided whenever they are at Euclidean distance <= 1 from each other after a step.

    Agents have five actions. Every call to ``get_next_loc`` draws one random distance from
    ``random.random()`` and moves the agent that far along the chosen axis:
    0 = Do nothing
    1 = Move up (decreasing y)
    2 = Right (increasing x)
    3 = Down (increasing y)
    4 = Left (decreasing x)
    A move whose destination lies outside every cell of ``grid_posns`` is cancelled and the
    agent stays where it was.
    """

    def __init__(self, grid_posns, num_agents, start_idx, ending_idx, randomize_starts: bool = False,
                 collision_reward=-30, agents_bounce: bool = False, terminate_on_collision: bool = False):
        """
        :param grid_posns: Indexable collection of hashable ``(x, y)`` integer cells that make up the map
        :param num_agents: Number of agents; must be 2
        :param start_idx: Per-agent index into ``grid_posns`` of the starting cell
        :param ending_idx: Per-agent index into ``grid_posns`` of the goal cell
        :param randomize_starts: If true, ``initial_state`` ignores ``start_idx`` and picks random cells
        :param collision_reward: Value added to every agent's reward on a step that ends in a collision
        :param agents_bounce: Accepted for interface compatibility only; must be false
        :param terminate_on_collision: If true, a collision ends the episode
        :raises Exception: if ``agents_bounce`` is set or ``num_agents`` is not 2
        """

        # Want to have same constructor interface as regular gridworld, but can't support all of the params
        if agents_bounce:
            raise Exception("Agents bounce not supported on continuous gridworld")
        if num_agents != 2:
            raise Exception("Continuous gridworld must have 2 agents")

        self.grid_posns = grid_posns
        self.num_agents = num_agents
        # NOTE: these are the largest cell *indices* along each axis, not the number of cells
        self.width = max(pos[0] for pos in self.grid_posns)
        self.height = max(pos[1] for pos in self.grid_posns)

        # posn -> idx
        self.grid_posn_inv = {pos: idx for idx, pos in enumerate(self.grid_posns)}

        assert len(start_idx) == num_agents, "start_idx must have one entry per agent"
        assert len(ending_idx) == num_agents, "ending_idx must have one entry per agent"

        self.start_idx = start_idx
        self.goal_idx = ending_idx

        self.randomize_starts = randomize_starts
        self.collision_cost = collision_reward
        self.terminate_on_collision = terminate_on_collision

    def ap_names(self) -> List[str]:
        """Return the names of the atomic propositions of this environment (just "collision")."""
        return ["collision"]

    def agent_obs_spaces(self) -> Sequence[spaces.Space]:
        """Return one observation space per agent: a box of ``2 * num_agents`` values in [-1, 1]."""
        # NOTE(review): x can reach just under width + 1 (an agent may sit anywhere inside the
        # last cell), so project_obs can emit values slightly above 1 -- confirm this is acceptable.
        return [spaces.Box(-1, 1, shape=(2 * self.num_agents,))] * self.num_agents

    def agent_actions_spaces(self) -> Sequence[spaces.Space]:
        """Return one action space per agent: the five discrete actions listed on the class."""
        return [spaces.Discrete(5)] * self.num_agents

    def state_space(self) -> spaces.Space:
        """Return the space of the flat ``(x0, y0, x1, y1)`` environment state."""
        # NOTE(review): the upper bound is the largest cell index, but coordinates inside the
        # last cell exceed it by up to (almost) 1 -- confirm whether the bound should be + 1.
        return spaces.Box(0, max(self.width, self.height), shape=(2 * self.num_agents,))

    def initial_state(self):
        """
        Build the state at the start of an episode.

        With ``randomize_starts`` the first agent gets a uniformly random cell and the second a
        uniformly random cell at distance >= 2 from it; otherwise the cells come from ``start_idx``.
        Each cell is then offset by (.5, .5).

        :return: ``(state, observations)`` where ``state`` is ``(x0, y0, x1, y1)``
        :raises IndexError: (from ``random.choice``) when starts are randomized and no cell is at
            distance >= 2 from the first agent's cell
        """
        if self.randomize_starts:
            starting_loc_0 = random.choice(self.grid_posns)
            # Collisions are distance <= 1, so require some slack between the two starting cells
            good_locs = [p for p in self.grid_posns if posn_dist(starting_loc_0, p) >= 2]
            starting_loc_1 = random.choice(good_locs)
        else:
            starting_loc_0 = self.grid_posns[self.start_idx[0]]
            starting_loc_1 = self.grid_posns[self.start_idx[1]]

        # Offset applied to the integer cell coordinates (presumably the centre of the cell)
        starting_space_in_loc = .5, .5

        starting_loc_0 = add_posn(starting_loc_0, starting_space_in_loc)
        starting_loc_1 = add_posn(starting_loc_1, starting_space_in_loc)

        starting_state = (*starting_loc_0, *starting_loc_1)

        return starting_state, self.project_obs(starting_state)

    def get_next_loc(self, loc, act):
        """
        Apply action ``act`` to the position ``loc``.

        A random distance is drawn on every call, including for the "do nothing" action.

        :param loc: Current ``(x, y)`` position of the agent
        :param act: Action index in 0..4 (see the class docstring)
        :return: The moved position if it falls inside a cell of the map, otherwise ``loc`` unchanged
        """
        dist_to_move = random.random()
        # Displacement for each action index; y grows downwards, so "up" is negative y
        displacements = [(0, 0),
                         (0, -dist_to_move),
                         (dist_to_move, 0),
                         (0, dist_to_move),
                         (-dist_to_move, 0)]

        next_loc = add_posn(loc, displacements[act])
        if posn_floor(next_loc) in self.grid_posn_inv:
            return next_loc

        # Destination is not on the map: the move is cancelled
        return loc

    def step(self, environment_state, joint_action: Sequence[Any]) -> Tuple[
        Any, Sequence[Any], Sequence[float], bool, bool]:
        """
        Advance the environment by one joint action.

        Rewards per agent: 100 if both agents are in their goal cells after the move, -10 if the
        agent chose a non-zero action but its position did not change, -1 otherwise. On a collision
        (distance <= 1 after the move) ``collision_cost`` is added to every agent's reward.

        :param environment_state: Current state ``(x0, y0, x1, y1)``
        :param joint_action: One action index per agent
        :return: ``(new_state, observations, rewards, done, safe)`` where ``safe`` is true
            exactly when no collision occurred on this step
        """

        agent_0_loc = environment_state[0:2]
        agent_1_loc = environment_state[2:4]
        agent_0_new_loc = self.get_next_loc(agent_0_loc, joint_action[0])
        agent_1_new_loc = self.get_next_loc(agent_1_loc, joint_action[1])

        collisions_or_crossings = posn_dist(agent_0_new_loc, agent_1_new_loc) <= 1

        # The episode is only solved once *every* agent sits in its own goal cell
        reached_goal = all(posn_floor(loc) == self.grid_posns[goal] for loc, goal in
                           zip((agent_0_new_loc, agent_1_new_loc), self.goal_idx))
        done = reached_goal

        def rew_for_agent(new_loc, old_loc, action):
            if reached_goal:
                return 100
            if new_loc == old_loc and action != 0:  # Hit a wall
                return -10
            return -1

        rewards = [rew_for_agent(agent_0_new_loc, agent_0_loc, joint_action[0]),
                   rew_for_agent(agent_1_new_loc, agent_1_loc, joint_action[1])]

        if collisions_or_crossings:
            rewards = [rew + self.collision_cost for rew in rewards]

            if self.terminate_on_collision:
                done = True

        new_env_state = (*agent_0_new_loc, *agent_1_new_loc)
        return new_env_state, self.project_obs(new_env_state), rewards, done, (not collisions_or_crossings)

    def project_obs(self, state) -> Sequence[Any]:
        """
        Turn an environment state into per-agent observations.

        Each x coordinate is mapped to ``2 * x / width - 1`` and each y coordinate to
        ``2 * y / height - 1``. Every agent receives the same resulting tuple.

        :param state: Flat state ``(x0, y0, x1, y1)``
        :return: A tuple holding ``num_agents`` references to the same observation tuple
        """
        # A single-row or single-column map has a largest index of 0 along that axis; divide by 1
        # there instead of raising ZeroDivisionError. Non-zero extents are used unchanged.
        divisors = [self.width or 1, self.height or 1] * self.num_agents
        obs = tuple(((num * 2) / div) - 1 for num, div in zip(state, divisors))
        return tuple([obs] * self.num_agents)
